package fieldmaskutil

import (
	"fmt"
	"strings"

	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
)

// NestedMask represents a field mask as a recursive map.
//
// Each key is a single path segment (a field name or, beneath a map field, a
// map key) and each value is the mask that applies below that segment.
// A nil or empty value terminates a path: the whole field is selected.
type NestedMask map[string]NestedMask

// NestedMaskFromPaths creates an instance of NestedMask for the given paths.
//
// For example ["foo.bar", "foo.baz"] becomes {"foo": {"bar": nil, "baz": nil}}.
//
// Overlapping paths are merged independently of their order: a path that
// selects a whole field subsumes every longer path beneath it, so both
// ["foo", "foo.bar"] and ["foo.bar", "foo"] become {"foo": nil}.
//
// Empty paths are ignored. A path containing an empty segment is invalid
// input: the segments before the first empty one may still be added, the rest
// of the path is dropped.
func NestedMaskFromPaths(paths []string) NestedMask {
	mask := make(NestedMask)
	for _, path := range paths {
		cur := mask
		for len(path) > 0 {
			dotIdx := strings.IndexByte(path, '.')
			if dotIdx == -1 {
				// Last segment: select the whole field, replacing any narrower
				// submask that an earlier path recorded for it.
				cur[path] = nil
				break
			}

			field := path[:dotIdx]
			if len(field) == 0 {
				// Invalid input.
				break
			}

			nested, seen := cur[field]
			if seen && len(nested) == 0 {
				// The whole field is already selected; a longer path must not
				// narrow that selection.
				break
			}
			if !seen {
				nested = make(NestedMask)
				cur[field] = nested
			}

			cur = nested
			path = path[dotIdx+1:]
		}
	}

	return mask
}

// Filter clears every populated field of msg that the mask does not list and
// keeps the listed ones.
//
// An empty mask keeps all the fields. A listed field with an empty submask is
// kept as is. For a listed field with a non-empty submask:
//   - map fields: the submask keys are matched against the map keys; entries
//     whose key is not listed are removed, and message values of listed
//     entries are filtered with the submask of their key;
//   - repeated fields: every element is filtered with the submask;
//   - message fields: the nested message is filtered with the submask.
//
// Paths are assumed to be valid and normalized otherwise the function may panic.
// See google.golang.org/protobuf/types/known/fieldmaskpb for details.
func (mask NestedMask) Filter(msg proto.Message) {
	if len(mask) == 0 {
		return
	}

	refl := msg.ProtoReflect()
	refl.Range(func(field protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
		sub, listed := mask[string(field.Name())]
		if !listed {
			// Not part of the mask: drop the field.
			refl.Clear(field)
			return true
		}
		if len(sub) == 0 {
			// The whole field is selected: nothing to filter inside it.
			return true
		}

		switch {
		case field.IsMap():
			entries := refl.Get(field).Map()
			entries.Range(func(key protoreflect.MapKey, val protoreflect.Value) bool {
				keyMask, keep := sub[key.String()]
				if !keep {
					entries.Clear(key)
					return true
				}
				// Only message values with a non-empty submask are filtered further.
				if nested, isMsg := val.Interface().(protoreflect.Message); isMsg && len(keyMask) > 0 {
					keyMask.Filter(nested.Interface())
				}
				return true
			})
		case field.IsList():
			items := refl.Get(field).List()
			for idx := 0; idx < items.Len(); idx++ {
				sub.Filter(items.Get(idx).Message().Interface())
			}
		case field.Kind() == protoreflect.MessageKind:
			sub.Filter(refl.Get(field).Message().Interface())
		}
		return true
	})
}

// Prune clears the fields of msg that are listed in the mask and leaves all
// the other fields untouched.
//
// An empty mask clears nothing. A listed field with an empty submask is
// cleared entirely. For a listed field with a non-empty submask:
//   - map fields: the submask keys are matched against the map keys; a listed
//     entry is pruned with the submask of its key when its value is a message
//     and that submask is non-empty, otherwise the entry is removed;
//   - repeated fields: every element is pruned with the submask;
//   - message fields: the nested message is pruned with the submask.
//
// This operation is the opposite of NestedMask.Filter.
// Paths are assumed to be valid and normalized otherwise the function may panic.
// See google.golang.org/protobuf/types/known/fieldmaskpb for details.
func (mask NestedMask) Prune(msg proto.Message) {
	if len(mask) == 0 {
		return
	}

	refl := msg.ProtoReflect()
	refl.Range(func(field protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
		sub, listed := mask[string(field.Name())]
		if !listed {
			// Not part of the mask: keep the field as is.
			return true
		}
		if len(sub) == 0 {
			// The whole field is selected: remove it.
			refl.Clear(field)
			return true
		}

		switch {
		case field.IsMap():
			entries := refl.Get(field).Map()
			entries.Range(func(key protoreflect.MapKey, val protoreflect.Value) bool {
				keyMask, selected := sub[key.String()]
				if !selected {
					return true
				}
				nested, isMsg := val.Interface().(protoreflect.Message)
				if isMsg && len(keyMask) > 0 {
					keyMask.Prune(nested.Interface())
				} else {
					entries.Clear(key)
				}
				return true
			})
		case field.IsList():
			items := refl.Get(field).List()
			for idx := 0; idx < items.Len(); idx++ {
				sub.Prune(items.Get(idx).Message().Interface())
			}
		case field.Kind() == protoreflect.MessageKind:
			sub.Prune(refl.Get(field).Message().Interface())
		}
		return true
	})
}

// Overwrite copies the fields listed in the mask from src into dest.
//
// Fields that are not listed are kept untouched; with an empty mask nothing is
// overwritten. Scalars, messages, repeated fields, and maps are supported.
// If the parent message of a listed field is not set in dest, it is created
// before the field is overwritten.
// If a listed field holds an empty value in src, it is cleared in dest.
// Paths are assumed to be valid and normalized otherwise the function may panic.
func (mask NestedMask) Overwrite(src, dest proto.Message) {
	from := src.ProtoReflect()
	to := dest.ProtoReflect()
	mask.overwrite(from, to)
}

// Validate checks if all paths are valid for specified message.
//
// Supports scalars, messages, repeated fields, and maps.
// validationModel is only used as a description of the message type the mask
// is checked against. A nil result means every path was accepted; otherwise
// the returned error is prefixed with "invalid mask: " and wraps the
// underlying cause, so it can be inspected with errors.Is / errors.As.
func (mask NestedMask) Validate(validationModel proto.Message) error {
	if err := mask.validate("", validationModel.ProtoReflect()); err != nil {
		return fmt.Errorf("invalid mask: %w", err)
	}

	return nil
}

// overwrite copies the fields selected by mask from srcRft into destRft.
//
// Field descriptors are resolved on srcRft and then used on destRft as well,
// so both messages are expected to share the same message type. A name that
// does not resolve to a field makes the function panic.
//
// For every mask entry:
//   - empty submask: the field is set in dest from src, or cleared in dest
//     when the src value is invalid or equal to the field default;
//   - map field: for each src entry, a key listed in the submask is written
//     to dest (message values with a non-empty key submask are rebuilt from a
//     new message holding only the masked fields), a key that is not listed
//     is removed from dest;
//   - repeated message field: dest is truncated to the src length and every
//     element is overwritten (or appended) position by position;
//   - message field: the dest message is created when unset, then overwritten
//     recursively.
func (mask NestedMask) overwrite(srcRft, destRft protoreflect.Message) {
	for srcFDName, submask := range mask {
		srcFD := srcRft.Descriptor().Fields().ByName(protoreflect.Name(srcFDName))
		srcVal := srcRft.Get(srcFD)
		if len(submask) == 0 {
			if isValid(srcFD, srcVal) && !srcVal.Equal(srcFD.Default()) {
				destRft.Set(srcFD, srcVal)
			} else {
				destRft.Clear(srcFD)
			}
		} else if srcFD.IsMap() && srcFD.Kind() == protoreflect.MessageKind {
			srcMap := srcVal.Map()
			// Mutable always yields a writable map owned by dest. Seeding an
			// unset dest map with the src map instead would make both messages
			// share one map, so the writes below would modify src, and it
			// would panic when the src map is unset (read-only) as well.
			destMap := destRft.Mutable(srcFD).Map()
			srcMap.Range(func(mk protoreflect.MapKey, mv protoreflect.Value) bool {
				if mi, ok := submask[mk.String()]; ok {
					if i, ok := mv.Interface().(protoreflect.Message); ok && len(mi) > 0 {
						// Start from a fresh message so that only the masked
						// fields of the src value end up in dest.
						newVal := protoreflect.ValueOf(i.New())
						destMap.Set(mk, newVal)
						mi.overwrite(mv.Message(), newVal.Message())
					} else {
						destMap.Set(mk, mv)
					}
				} else {
					destMap.Clear(mk)
				}
				return true
			})
		} else if srcFD.IsList() && srcFD.Kind() == protoreflect.MessageKind {
			srcList := srcVal.List()
			destList := destRft.Mutable(srcFD).List()
			// Truncate anything in dest that exceeds the length of src
			if srcList.Len() < destList.Len() {
				destList.Truncate(srcList.Len())
			}
			for i := 0; i < srcList.Len(); i++ {
				srcListItem := srcList.Get(i)
				var destListItem protoreflect.Message
				if destList.Len() > i {
					// Overwrite existing items.
					destListItem = destList.Get(i).Message()
				} else {
					// Append new items to overwrite.
					destListItem = destList.AppendMutable().Message()
				}
				submask.overwrite(srcListItem.Message(), destListItem)
			}
		} else if srcFD.Kind() == protoreflect.MessageKind {
			// If the dest field is nil, create it so nested fields can be written.
			if !destRft.Get(srcFD).Message().IsValid() {
				destRft.Set(srcFD, protoreflect.ValueOf(destRft.Get(srcFD).Message().New()))
			}
			submask.overwrite(srcVal.Message(), destRft.Get(srcFD).Message())
		}
	}
}

// validate checks every entry of the mask against the descriptor of msg and
// returns an error for the first entry found to be invalid, or nil.
//
// pathPrefix is the dotted path leading to msg and is only used to build error
// messages. An entry is invalid when its name is not a field of msg, or when
// it carries a non-empty submask but the field cannot hold nested fields
// (a scalar, a list of non-messages, or a map with non-message values).
// Submasks of valid entries are validated recursively against the nested
// message type.
//
// NOTE(review): for map fields the submask is validated directly against the
// map value message, whereas Filter, Prune and overwrite look the submask
// keys up as map keys — confirm which interpretation is intended.
func (mask NestedMask) validate(pathPrefix string, msg protoreflect.Message) error {
	for name, sub := range mask {
		fd := msg.Descriptor().Fields().ByName(protoreflect.Name(name))
		if fd == nil {
			return fmt.Errorf("unknown path: '%s'", fullPath(pathPrefix, name))
		}

		// A terminal entry only needs the field to exist.
		if len(sub) == 0 {
			continue
		}

		path := fullPath(pathPrefix, name)

		var child protoreflect.Message
		switch {
		case fd.IsList():
			elem, isMsg := msg.Get(fd).List().NewElement().Interface().(protoreflect.Message)
			if !isMsg {
				return fmt.Errorf("'%s': list element isn't message kind", path)
			}
			child = elem
		case fd.IsMap():
			val, isMsg := msg.Get(fd).Map().NewValue().Interface().(protoreflect.Message)
			if !isMsg {
				return fmt.Errorf("'%s': map value isn't message kind", path)
			}
			child = val
		case fd.Kind() == protoreflect.MessageKind:
			child = msg.Get(fd).Message()
		default:
			return fmt.Errorf("'%s': can't get nested fields", path)
		}

		if err := sub.validate(path, child); err != nil {
			return err
		}
	}

	return nil
}

// fullPath joins pathPrefix and field with a dot.
// When pathPrefix is empty, field is returned unchanged.
func fullPath(pathPrefix, field string) string {
	if pathPrefix != "" {
		return pathPrefix + "." + field
	}

	return field
}

// isValid reports whether val holds a valid value for the field fd.
//
// For map, list and message fields the result is the IsValid state of the
// corresponding composite value; every other field is always considered valid.
func isValid(fd protoreflect.FieldDescriptor, val protoreflect.Value) bool {
	switch {
	case fd.IsMap():
		return val.Map().IsValid()
	case fd.IsList():
		return val.List().IsValid()
	case fd.Message() != nil:
		return val.Message().IsValid()
	default:
		return true
	}
}

// PathsFromFieldNumbers converts protobuf field numbers to field paths for the given message.
//
// This function takes a protobuf message and a list of field numbers, and
// returns a slice of field paths (field names) corresponding to those numbers.
// The returned names are the declared field names, i.e. the same names the
// other functions of this package use to look fields up.
//
// Field numbers that don't exist in the message descriptor are skipped,
// including numbers that do not fit into a protobuf field number.
//
// If no field numbers are provided, returns nil.
//
// Example:
//
//	// For a message with fields: name (field 1), age (field 2), address (field 3)
//	paths := PathsFromFieldNumbers(msg, 1, 2)
//	// Returns: ["name", "age"]
func PathsFromFieldNumbers(msg proto.Message, fieldNumbers ...int) []string {
	if len(fieldNumbers) == 0 {
		return nil
	}

	fields := msg.ProtoReflect().Descriptor().Fields()
	paths := make([]string, 0, len(fieldNumbers))
	for _, n := range fieldNumbers {
		num := protoreflect.FieldNumber(n)
		if int(num) != n {
			// n overflows the field number type; using the wrapped value
			// could match an unrelated field.
			continue
		}
		if fd := fields.ByNumber(num); fd != nil {
			// Use Name rather than TextName: the two differ for group fields,
			// and masks in this package are matched against Name.
			paths = append(paths, string(fd.Name()))
		}
	}

	return paths
}

// NilValuePaths returns the entries of paths that name a field of msg which
// is not populated, as reported by the message's Has method.
//
// Each path is looked up as a single top-level field name; entries that do
// not match a field of msg are skipped. The order of paths is preserved.
// Returns nil when paths is empty or when no entry qualifies.
func NilValuePaths(msg proto.Message, paths []string) []string {
	if len(paths) == 0 {
		return nil
	}

	refl := msg.ProtoReflect()
	fields := refl.Descriptor().Fields()

	var unset []string
	for _, name := range paths {
		fd := fields.ByName(protoreflect.Name(name))
		if fd != nil && !refl.Has(fd) {
			unset = append(unset, name)
		}
	}

	return unset
}
